import time
import torch
import torch.optim as optim
import torch.utils.data
from tqdm.auto import tqdm
import gc
import pickle
from pathlib import Path
import numpy as np
import pandas as pd
from transformers import RobertaConfig, RobertaForMaskedLM
from transformers import pipeline
from transformers import AutoModelForMaskedLM, AutoTokenizer
from tokenizers import ByteLevelBPETokenizer
from tokenizers.processors import BertProcessing

# Load the pre-tokenized data.
# The rest of the script reads the 'labels' and 'attention_mask' entries and
# (re)writes 'input_ids' at the start of every epoch.
# NOTE: unpickling can execute arbitrary code; only load files you trust.
encodings_path = Path("/home/projects/home/PureBPE/28032024_purebpe")
with encodings_path.open('rb') as encodings_file:
    encodings = pickle.load(encodings_file)

# Dynamic masking for masked-language-model training
def mlm(tensor: torch.Tensor,
        mask_prob: float = 0.15,
        mask_token_id: int = 4,
        max_special_id: int = 2) -> torch.Tensor:
    """Randomly replace token ids in ``tensor`` with the mask token, in place.

    Every element whose id is greater than ``max_special_id`` is independently
    replaced by ``mask_token_id`` with probability ``mask_prob``. Ids less than
    or equal to ``max_special_id`` are never touched (presumably the
    <s>/<pad>/</s> ids of the tokenizer -- confirm against the vocabulary).
    Selected tokens are always replaced by the mask id; there is no
    random-token / keep-original split.

    Args:
        tensor: Tensor of token ids. It is MODIFIED IN PLACE, so pass a
            ``.clone()`` if the original ids are still needed (e.g. as labels).
        mask_prob: Per-token masking probability.
        mask_token_id: Id written into the selected positions.
        max_special_id: Largest id that is protected from masking.

    Returns:
        The same ``tensor`` object, after masking.
    """
    # Draw on the tensor's own device so the comparison below also works for
    # tensors that do not live on the CPU.
    rand = torch.rand(tensor.shape, device=tensor.device)
    mask_arr = (rand < mask_prob) & (tensor > max_special_id)
    tensor[mask_arr] = mask_token_id
    return tensor

# Force a garbage-collection pass before the model is built.
gc.collect()

# Model initialisation
# True: train a new model from scratch; False: continue training an existing checkpoint.
new_model = False
if new_model:
    config = RobertaConfig(
        vocab_size=50264,
        max_position_embeddings=258,
        hidden_size=576,
        num_attention_heads=12,
        num_hidden_layers=6,
        # RobertaConfig has no `dropout` argument (an unknown keyword is just
        # stored on the config and never used by the model); the dropout rates
        # are controlled by the two parameters below.
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        type_vocab_size=1,
    )

    model = RobertaForMaskedLM(config)
else:
    # Checkpoint to resume from (the current model).
    model_path = '/home/projects/home/PureBPE/models/PureBPE_epoch_24_encodings_2'
    model = AutoModelForMaskedLM.from_pretrained(model_path)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    model = torch.nn.DataParallel(model)
model.to(device)
# Single switch to training mode; this covers both branches above and, when
# wrapped, propagates through DataParallel to the underlying model.
model.train()

# Optimizer initialisation
# NOTE(review): only the model weights are restored when resuming; the AdamW
# state is created fresh here on every run.
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)

# Dataset wrapper around the encodings mapping
class Dataset(torch.utils.data.Dataset):
    """Expose a mapping of equally indexed tensors as a map-style dataset.

    ``encodings`` maps field names to tensors; it must contain an
    ``'input_ids'`` entry, whose first dimension defines the dataset length.
    """

    def __init__(self, encodings):
        # Keep a reference only; nothing is copied.
        self.encodings = encodings

    def __len__(self):
        """Return the number of samples (first dimension of ``'input_ids'``)."""
        input_ids = self.encodings['input_ids']
        return input_ids.shape[0]

    def __getitem__(self, i):
        """Return sample ``i`` as a dict with the same keys as ``encodings``."""
        sample = {}
        for name, values in self.encodings.items():
            sample[name] = values[i]
        return sample

# Model training: one pass over the data per epoch, with fresh masking each epoch.
CHECKPOINT_INTERVAL_S = 3600  # write an intermediate checkpoint at most this often (seconds)


def _save_checkpoint(model, epoch):
    """Save ``model`` for ``epoch``, unwrapping ``DataParallel`` if present.

    Intermediate and end-of-epoch checkpoints share the same path, so a later
    save for the same epoch overwrites the earlier one.
    """
    model_save_path = f'/home/projects/home/PureBPE/models/PureBPE_epoch_{epoch}_encodings_{2}'
    # DataParallel has no save_pretrained(); the real model lives in `.module`.
    model_without_dp = model.module if isinstance(model, torch.nn.DataParallel) else model
    model_without_dp.save_pretrained(model_save_path)


# Loop-invariant loader settings.
batch_size = 128
num_workers = 0

for epoch in range(25, 30):  # outer loop over epochs (resumes after epoch 24)
    start_time = time.time()  # time of the last checkpoint (or epoch start)
    cas = time.strftime('%Y-%m-%d-%H%M%S', time.localtime())  # timestamp used in the log file name

    # Dynamic masking: re-draw the masked positions from the clean ids every epoch.
    # NOTE(review): 'labels' holds the unmasked ids for every position, so the
    # loss below is computed over all tokens rather than only the masked ones
    # (positions would need label -100 to be ignored) -- confirm this is intended.
    encodings['input_ids'] = mlm(encodings['labels'].clone())
    dataset = Dataset(encodings)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )

    # Training
    loop = tqdm(dataloader, leave=True)
    loop.set_description(f'Epoch {epoch}')
    model.train()

    # Open the per-epoch loss log once instead of re-opening it for every batch.
    log_path = f"/home/projects/home/PureBPE/logs/epoch_{epoch}_encodings_{2}_{cas}"
    with open(log_path, 'a') as log_file:
        for batch in loop:
            optimizer.zero_grad()
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            # Under DataParallel the loss comes back as one value per device;
            # mean() reduces it to a scalar (and is a no-op for a single value).
            loss = outputs.loss.mean()
            loss.backward()
            optimizer.step()

            # Read the scalar once; every .item() call forces a device sync.
            loss_value = loss.item()
            loop.set_postfix(loss=loss_value)

            log_file.write(f"{loss_value}\n")
            # Flush per step so the log can be followed live and survives a crash.
            log_file.flush()

            # Periodic safety checkpoint during long epochs.
            if time.time() - start_time >= CHECKPOINT_INTERVAL_S:
                _save_checkpoint(model, epoch)
                start_time = time.time()

    # End-of-epoch checkpoint.
    _save_checkpoint(model, epoch)
    gc.collect()

